import sys

import torch

sys.path.append("src")
from dataset.dataset import get_loaders
from pruner.utils import WrappedGPT, find_layers, prepare_calibration_input


def wanda_importance(
    model, tokenizer, n_samples, seed, device=torch.device("cuda:0"), dataset_name="c4"
):
    """
    Compute the Wanda importance score of every prunable weight matrix.

    Calibration samples are pushed through the decoder layers one layer at a
    time.  While a layer runs, forward hooks feed the input/output of each
    sub-module returned by ``find_layers`` into a ``WrappedGPT`` wrapper.  The
    score of a sub-module is then ``|weight| * sqrt(scaler_row)``, where
    ``scaler_row`` is the statistic exposed by that wrapper (presumably a
    per-input-feature activation norm -- confirm against ``WrappedGPT``).

    Args:
        model: Causal LM exposing ``config.use_cache``, ``seqlen`` and either
            ``model.layers`` or ``model.decoder.layers``.  ``hf_device_map``
            is honoured when present.
        tokenizer: Tokenizer forwarded to ``get_loaders``.
        n_samples: Number of calibration samples to load and run.
        seed: Seed forwarded to ``get_loaders``.
        device: Device forwarded to ``prepare_calibration_input``.
        dataset_name: Calibration dataset name forwarded to ``get_loaders``.

    Returns:
        dict mapping a running integer index (layers in order, sub-modules in
        the iteration order of ``find_layers``) to a CPU tensor of scores.
    """

    def make_hook(wrapper):
        # Bind the wrapper explicitly so each hook feeds its own statistics.
        def hook(_module, inp, out):
            wrapper.add_batch(inp[0].data, out.data)

        return hook

    W_metrics = {}
    use_cache = model.config.use_cache
    model.config.use_cache = False
    try:
        dataloader, _ = get_loaders(
            name=dataset_name,
            nsamples=n_samples,
            seed=seed,
            seqlen=model.seqlen,
            tokenizer=tokenizer,
        )

        with torch.no_grad():
            inps, outs, attention_mask, position_ids = prepare_calibration_input(
                model, dataloader, device
            )

        try:
            layers = model.model.layers
        except AttributeError:
            # Fallback for architectures that nest the layers under a decoder.
            layers = model.model.decoder.layers

        # Only set when the model was dispatched with a device map; a model
        # living on a single device simply has nothing to move.
        device_map = getattr(model, "hf_device_map", None) or {}

        cnt = 0
        for i, layer in enumerate(layers):
            subset = find_layers(layer)

            layer_key = f"model.layers.{i}"
            if layer_key in device_map:
                # Multi-GPU case (e.g. llama-30B / llama-65B): follow the
                # layer to whichever device it was placed on.
                dev = device_map[layer_key]
                inps, outs = inps.to(dev), outs.to(dev)
                if attention_mask is not None:
                    attention_mask = attention_mask.to(dev)
                if position_ids is not None:
                    position_ids = position_ids.to(dev)

            # Built per layer because the tensors may just have been moved.
            layer_kwargs = {"attention_mask": attention_mask}
            if position_ids is not None:
                # Must be an identity check: truth-testing a multi-element
                # tensor raises, and a lone position id of 0 would be falsy.
                layer_kwargs["position_ids"] = position_ids

            wrapped_layers = {name: WrappedGPT(subset[name]) for name in subset}
            handles = [
                subset[name].register_forward_hook(make_hook(wrapper))
                for name, wrapper in wrapped_layers.items()
            ]
            try:
                with torch.no_grad():
                    for j in range(n_samples):
                        outs[j] = layer(inps[j].unsqueeze(0), **layer_kwargs)[0]
            finally:
                # Never leave hooks attached to the caller's model.
                for handle in handles:
                    handle.remove()

            for name in subset:
                activation_scale = torch.sqrt(
                    wrapped_layers[name].scaler_row.reshape((1, -1))
                )
                W_metrics[cnt] = (
                    (torch.abs(subset[name].weight.data) * activation_scale)
                    .detach()
                    .cpu()
                )
                cnt += 1

            # No weights are modified here, so the outputs computed above are
            # already the inputs of the next layer; a second forward pass
            # would only recompute the same values.
            inps, outs = outs, inps
    finally:
        # Restore the caller's cache setting even if calibration failed.
        model.config.use_cache = use_cache

    torch.cuda.empty_cache()

    return W_metrics
